(ns proteins
(:require [tablecloth.api :as tc]
[fastmath.core :as math]
[fastmath.random :as random]
[tech.v3.datatype :as dtype]
[tech.v3.dataset :as dataset]
[tech.v3.tensor :as tensor]
[tech.v3.datatype.functional :as fun]
[aerial.hanami.common :as hc]
[aerial.hanami.templates :as ht]
[scicloj.kindly.v3.kind :as kind]
[scicloj.kindly.v3.api :as kindly]
[scicloj.clay.v2.api :as clay]
[libpython-clj2.python :refer [py. py.. py.-] :as py]
[scicloj.noj.v1.vis :as vis]
[scicloj.noj.v1.vis.python :as vis.python]
[libpython-clj2.require :refer [require-python]]
[util])
(:import java.lang.Math))(require-python '[builtins :as python]
'operator
'[arviz :as az]
'[arviz.style :as az.style]
'[pandas :as pd]
'[matplotlib.pyplot :as plt]
'[numpy :as np]
'[numpy.random :as np.random]
'[pymc :as pm]
'[Bio.PDB.PDBParser]
'[Bio.PDB]
'[Bio.PDB.Polypeptide]
'[pytensor]
'[pytensor.tensor :as pt]
'[math]):ok
(def protein-name1 "7ju5clean")(def protein-name2 "AF-A0A024R7T2-F1-model_v4-clean")(def protein-names [protein-name1 protein-name2])(defn name->path [nam]
(str "data/" nam ".pdb"))(defonce protein-name->pdb-data
(memoize
(fn [protein-name]
(-> protein-name
name->path
slurp))))nil
(defn pdb-view [pdb]
(kind/hiccup
['(fn [pdb]
[:div
{:style {:width "100%"
:height "500px"
:position "relative"}
:ref (fn [el]
(let [config (clj->js
{:backgroundColor "0xffffff"})
viewer (.createViewer js/$3Dmol el #_config)]
(.addModelsAsFrames viewer pdb "pdb")
(.setStyle viewer
(clj->js {})
(clj->js {:stick {:color :spectrum}}))
(.addSphere viewer (clj->js
{:center {:x 0
:y 0
:z 0}
:radius 1
:color "green"}))
(.zoomTo viewer)
(.render viewer)
(.zoom viewer 0.8 1000)))}
;; need to keep this symbol to let Clay infer the necessary dependency
'three-d-mol])
pdb]))(->> protein-names
(map (fn [nam]
[nam
(-> nam
protein-name->pdb-data
pdb-view)]) )
(into {}))(defn shapes-view [shapes]
(kind/hiccup
['(fn [shapes]
[:div
{:style {:width "100%"
:height "500px"
:position "relative"}
:ref (fn [el]
(let [config (clj->js
{:backgroundColor "0xffffff"})
viewer (.createViewer js/$3Dmol el #_config)]
(doseq [[shape-type shape-data] shapes]
(case shape-type
:sphere (.addSphere viewer (clj->js shape-data))
:cylinder (.addCylinder viewer (clj->js shape-data))))
(.zoomTo viewer)
(.render viewer)
(.zoom viewer 0.8 1000)))}
;; need to keep this symbol to let Clay infer the necessary dependency
'three-d-mol])
(vec shapes)]))(-> [[:sphere {:center {:x 0
:y 0
:z 0}
:radius 1
:color "green"}]
[:cylinder {:start {:x 0 :y 10 :z 20}
:end {:x 10 :y 0 :z 30}
:radius 0.5
:fromCap false
:toCap true
:color :teal
:alpha 0.5}]]
shapes-view)(defn extract-coordinates-from-pdb
([protein-name]
(let [filepath (name->path protein-name)
parser (Bio.PDB/PDBParser)
structure (py. parser get_structure protein-name filepath)]
(-> structure
first
((fn [model]
(-> model
(->> (mapcat
(fn [chain]
(->> chain
(filter (fn [residue]
(-> residue
(py. get_resname)
(Bio.PDB.Polypeptide/is_aa :standard true))))
(map (fn [residue]
{:id (-> residue
(py. get_id)
second)
:name (-> residue
(py. get_resname))
:ca-coordinates (try
(-> residue
(util/brackets "CA")
(py. get_coord)
(->> (dtype/->array :float32)))
(catch Exception e nil))}))
(filter :ca-coordinates))))
tc/dataset))))))))(-> protein-name1
extract-coordinates-from-pdb
;; for readability of output:
(tc/update-columns [:ca-coordinates]
(partial map vec)))(defn center-1d [xs]
(fun/- xs
(fun/mean xs)))(defn center-columns [xyzs]
(-> xyzs
(tensor/map-axis center-1d 0)))(defn read-data
([prots]
(read-data prots nil))
([prots {:keys [limit]}]
(let [prots [protein-name1 protein-name2]
[dataset1 dataset2] (->> prots
(map extract-coordinates-from-pdb))
joined-dataset (-> (tc/inner-join dataset1 dataset2 :id)
((if limit
#(tc/head % limit)
identity)))
coords (->> [:ca-coordinates :right.ca-coordinates]
(map (fn [colname]
(-> colname
joined-dataset
tensor/->tensor))))
obs (->> coords
(mapv #(tensor/map-axis % center-1d 0)))
obs-datasets (->> obs
(mapv util/xyz-tensor->dataset))]
{:coords coords
:obs obs
:obs-datasets obs-datasets})))(-> [protein-name1 protein-name2]
(read-data {:limit 4})
:obs-datasets)Compare the datasets visually
(defn xyz-dataset->shapes [dataset options]
(-> dataset
(tc/rows :as-maps)
(->> ((juxt identity rest))
(apply mapv (fn [xyz0 xyz1]
[:cylinder (merge {:start xyz0
:end xyz1}
options)])))))(let [{:keys [obs obs-datasets]} (-> [protein-name1 protein-name2]
read-data)
colors [:purple :orange]
radii [1 1]
view-limit 50]
(->> [obs-datasets colors radii]
(apply mapcat
(fn [dataset color radius]
(-> dataset
(tc/head view-limit)
(xyz-dataset->shapes {:radius radius
:color color}))))
shapes-view))(let [{:keys [obs obs-datasets]} (-> [protein-name1 protein-name2]
read-data)
structures (->> obs
(mapv #(-> %
(tensor/transpose [1 0]))))
view-limit 50
tensor->cljs (fn [tensor]
(-> tensor
(tensor/transpose [1 0])
util/xyz-tensor->dataset
(tc/head view-limit)
util/prep-dataset-for-cljs))]
(->> {:prot1-dataset (-> structures
first
tensor->cljs)
:prot2-dataset (-> structures
second
tensor->cljs)}
(vector '(fn [{:keys [prot1-dataset
prot2-dataset]}]
[plotly
{:data [(-> prot1-dataset
(merge {:type :scatter3d
:mode :lines+markers
:opacity 1
:marker {:size 3
:color "purple"}}))
(-> prot2-dataset
(merge {:type :scatter3d
:mode :lines+markers
:opacity 1
:marker {:size 3
:color "orange"}}))]}]))
kind/hiccup))(defn rotate-q [u]
(let [theta1 (-> u
(util/brackets 1)
(operator/mul (* 2 Math/PI)))
theta2 (-> u
(util/brackets 2)
(operator/mul (* 2 Math/PI)))
r1 (-> u
(util/brackets 0)
(->> (operator/sub 1))
pt/sqrt)
r2 (-> u
(util/brackets 0)
pt/sqrt)
w (-> theta2
(pt/cos)
(operator/mul r2))
x (-> theta1
(pt/sin)
(operator/mul r1))
y (-> theta1
(pt/cos)
(operator/mul r1))
z (-> theta2
(pt/sin)
(operator/mul r2))
R00 (operator/sub (operator/add (pt/sqr w)
(pt/sqr x))
(operator/add (pt/sqr y)
(pt/sqr z)))
R11 (operator/sub (operator/add (pt/sqr w)
(pt/sqr y))
(operator/add (pt/sqr x)
(pt/sqr z)))
R22 (operator/sub (operator/add (pt/sqr w)
(pt/sqr z))
(operator/add (pt/sqr x)
(pt/sqr y)))
R01 (operator/mul 2
(operator/sub (operator/mul x y)
(operator/mul w z)))
R02 (operator/mul 2
(operator/add (operator/mul x z)
(operator/mul w y)))
R10 (operator/mul 2
(operator/add (operator/mul x y)
(operator/mul w z)))
R12 (operator/mul 2
(operator/sub (operator/mul y z)
(operator/mul w x)))
R20 (operator/mul 2
(operator/sub (operator/mul x z)
(operator/mul w y)))
R21 (operator/mul 2
(operator/add (operator/mul y z)
(operator/mul w x)))]
(pt/stack [(pt/stack [R00 R01 R02])
(pt/stack [R10 R11 R12])
(pt/stack [R20 R21 R22])])))(defonce model
(memoize
(fn [{:keys [residues-limit tune]}]
(let [{:keys [obs obs-datasets]}
(read-data [protein-name1 protein-name2]
{:limit residues-limit})
structures (->> obs
(mapv #(-> %
(tensor/transpose [1 0]))))
np-structures (->> structures
(mapv util/tensor2d->np-matrix))
shape (-> (obs 0)
dtype/shape
reverse
vec)
[space-dimension n-residues] shape]
(py/with [model (pm/Model)]
(let [M (pm/Cauchy "M"
:alpha 0
:beta 1
:shape shape)
M0 (pm/Deterministic "M0"
(operator/sub
M
(pt/mean M)))
t (pm/Normal "t" :shape [space-dimension]) ; the shift
u (pm/Uniform "u" :shape [space-dimension]) ; randomization of rotation
R (pm/Deterministic "R" (rotate-q u)) ; the rotation matrix
U (pm/HalfNormal "U"
:sigma 0.01 ; TODO: Consider some prior here
:shape [n-residues])
M0_rotated (pm/Deterministic "M0_rotated"
(pt/dot R M0))
X1 (pm/MatrixNormal "X1"
:mu M0
:rowcov (np/eye space-dimension)
:colcov (pt/diag U)
:observed (np-structures 0))
X2 (pm/MatrixNormal "X2"
:mu (-> M0_rotated
;; conjugating with transpose
;; to make broadcasting work
pt/transpose
(operator/add t)
pt/transpose)
:rowcov (np/eye space-dimension)
:colcov (pt/diag U)
:observed (np-structures 1))
M0_adapted (pm/Deterministic "M0_adapted"
(-> (pt/dot R M0)
pt/transpose
(operator/add t)
pt/transpose))
X1_adapted (pm/Deterministic "X1_adapted"
(-> (pt/dot R X1)
pt/transpose
(operator/add t)
pt/transpose))
prot1_adapted (pm/Deterministic "prot1_adapted"
(-> (np-structures 0)
(->> (pt/dot R))
pt/transpose
(operator/add t)
pt/transpose))
prior-predictive-samples (pm/sample_prior_predictive)
idata (pm/sample :chains 1
:draws 200
:tune tune)
posterior-predictive-samples (pm/sample_posterior_predictive
idata)]
{:structures structures
:prior-predictive-samples prior-predictive-samples
:posterior-predictive-samples posterior-predictive-samples
:idata idata}))))))nil
(defn show-results-3dmol [results {:keys [residues-view-limit
samples-view-limit]}]
(let [tensor->dataset (fn [tensor]
(-> tensor
(tensor/transpose [1 0])
util/xyz-tensor->dataset
(tc/head residues-view-limit)))
shape (-> results
:idata
(py.- posterior)
(py.- prot1_adapted)
np/shape)
n-chains (first shape)
n-samples (second shape)
prot1-adapted-datasets (-> results
:idata
(py.- posterior)
(py.- prot1_adapted)
util/py-array->clj
(tensor/slice 1)
(->> (map-indexed
(fn [chain-idx chain-tensor]
(-> chain-tensor
(tensor/slice 1)
(->> (map tensor->dataset)))))
(apply concat)
vec))
prot2-dataset (-> results
:structures
second
tensor->dataset)]
(->> prot1-adapted-datasets
(take samples-view-limit)
(mapcat #(xyz-dataset->shapes
%
{:alpha 0.4
:radius 0.1
:color :purple}))
(concat (-> prot2-dataset
(xyz-dataset->shapes
{:radius 0.3
:color :orange})))
shapes-view)))(defn show-results [results {:keys [residues-view-limit
samples-view-limit]}]
(let [tensor->cljs (fn [tensor aname]
(-> tensor
(tensor/transpose [1 0])
util/xyz-tensor->dataset
(tc/head residues-view-limit)
util/prep-dataset-for-cljs))
shape (-> results
:idata
(py.- posterior)
(py.- prot1_adapted)
np/shape)
n-chains (first shape)
n-samples (second shape)]
(->> {:prot1-adapted-datasets
(-> results
:idata
(py.- posterior)
(py.- prot1_adapted)
util/py-array->clj
(tensor/slice 1)
(->> (map-indexed
(fn [chain-idx chain-tensor]
(-> chain-tensor
(tensor/slice 1)
(->> (map #(tensor->cljs
%
(str "prot1-adapted-chain"
chain-idx)))))))
(apply concat)
vec))
:prot1-chain-idx (->> n-chains
range
(mapcat (fn [chain-idx]
(repeat n-samples chain-idx)))
vec)
:prot2-dataset
(-> results
:structures
second
(tensor->cljs "prot2"))}
(vector '(fn [{:keys [prot1-adapted-datasets
prot1-chain-idx
prot2-dataset]}]
[plotly
{:data (->> prot1-adapted-datasets
(map (fn [dataset]
(-> dataset
(merge {:type :scatter3d
:mode :lines+markers
:opacity 0.1
:marker {:size 3
:color
(mapv
["blue"
"yellow"
"red"
"green"]
prot1-chain-idx)}}))))
(cons (-> prot2-dataset
(merge {:type :scatter3d
:mode :lines+markers
:opacity 1
:marker {:size 3
:color "orange"}})))
vec)}]))
kind/hiccup)))(->> [5 15 50 200]
(mapcat (fn [tune]
(let [m (model {:residues-limit 50
:tune tune})]
[{:tune tune}
(show-results m {:residues-view-limit 50
:samples-view-limit 10})
(show-results-3dmol m {:residues-view-limit 50
:samples-view-limit 10})])))):bye:bye